import os
import json
import math
import h5py
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
from torch.utils.data import TensorDataset, DataLoader
from lstm_model_standalone import LSTMModel, StackedLSTMWithLayerNorm


# Utility: build test sequences from HDF5 

def build_test_sequences(h5_path, timesteps):
    """
    Build sliding-window test sequences from an HDF5 file.

    Every top-level key of the file is read as a group holding two datasets:
    'scaled_parameters' (features) and 'normalized_current' (target).
    For each start index j in range(len(target) - timesteps) one window is
    emitted:
        X = features[j : j + timesteps]
        y = target[j + 1 : j + 1 + timesteps]   (target shifted one step ahead)
    Groups whose target is not longer than `timesteps` contribute no windows.

    Args:
        h5_path: Path to the HDF5 file.
        timesteps: Window length; must be positive.

    Returns:
        Tuple (X, y) of float32 arrays, stacked over all windows of all groups.

    Raises:
        ValueError: If `timesteps` is not positive, or if no group yields
            at least one window.
    """
    if timesteps <= 0:
        raise ValueError(f"timesteps must be a positive integer, got {timesteps}")

    Xs, ys = [], []
    with h5py.File(h5_path, 'r') as f:
        for key in f.keys():
            grp = f[key]
            feats = grp['scaled_parameters'][:]      # read fully into memory
            targ  = grp['normalized_current'][:]     # read fully into memory
            # The last valid start is len(targ) - timesteps - 1 so that the
            # one-step-ahead target slice still has `timesteps` elements.
            for j in range(len(targ) - timesteps):
                Xs.append(feats[j:j+timesteps])
                ys.append(targ[j+1:j+1+timesteps])

    # np.stack([]) fails with an unhelpful message; say what actually went wrong.
    if not Xs:
        raise ValueError(
            f"No windows of length {timesteps} could be built from {h5_path}: "
            f"the file has no groups, or every group has at most {timesteps} samples"
        )

    return np.stack(Xs).astype(np.float32), np.stack(ys).astype(np.float32)


#  Core evaluation

def _compute_nmae_new(preds, truths, eps=1e-12):
    """
    Per-window nMAE in percent.

    Each window's MAE is divided by 2 * max(|truth|, |pred|) taken over that
    window (clipped below at eps so an all-zero window cannot divide by zero)
    and then scaled by 100. Windows are indexed along axis 0, time along axis 1.
    """
    peak = np.maximum(np.max(np.abs(truths), axis=1),
                      np.max(np.abs(preds), axis=1))
    peak = np.clip(peak, eps, None)
    mae_win = np.mean(np.abs(preds - truths), axis=1)
    return mae_win / (2.0 * peak) * 100.0  # %


def _save_sorted_curve(values, label, xlabel, ylabel, title, out_path):
    """Plot `values` in ascending order and save the figure to `out_path`."""
    plt.figure(figsize=(8, 4))
    plt.plot(np.sort(values), label=label)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close()


def evaluate_run(
    run_dir,
    test_h5_path,
    which_ckpt="best_by_nrmse",   # or "best_by_mse" or "last"
    batch_size=256,
    num_plot_windows=10,
    dt=0.1
):
    """
    Evaluate a trained run on a test HDF5 file and write metrics and plots.

    Reads `hparams.json` and `<which_ckpt>.pth` from `run_dir`, rebuilds the
    model, runs it over all test windows, prints MSE / MAE / NRMSE / nMAE and
    saves `metrics.txt` plus diagnostic plots into `run_dir/eval`.

    Args:
        run_dir: Training run directory containing hparams.json and *.pth files.
        test_h5_path: Path to the test HDF5 file (see build_test_sequences).
        which_ckpt: One of "best_by_nrmse", "best_by_mse", "last".
        batch_size: Batch size used for inference.
        num_plot_windows: Number of randomly chosen windows to plot
            (capped at the number of test windows; negative values plot none).
        dt: Spacing between consecutive samples on the time axis of the
            per-window plots.

    Raises:
        ValueError: If `which_ckpt` or the `model_type` stored in hparams.json
            is not recognised.
        FileNotFoundError: If hparams.json or the checkpoint file is missing.
        KeyError: If hparams.json lacks hidden_size, num_layers or dropout.
    """
    # Validate cheap things first so a typo does not cost a full data load.
    if which_ckpt not in ("best_by_nrmse", "best_by_mse", "last"):
        raise ValueError("which_ckpt must be one of: best_by_nrmse, best_by_mse, last")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Eval] Device: {device}")

    #  Load hparams from training run
    hparams_path = os.path.join(run_dir, "hparams.json")
    if not os.path.exists(hparams_path):
        raise FileNotFoundError(f"hparams.json not found in run_dir: {hparams_path}")

    with open(hparams_path, "r") as f:
        hp = json.load(f)

    model_type  = hp.get("model_type", "layernorm")  # default if older runs miss this field
    hidden_size = hp["hidden_size"]
    num_layers  = hp["num_layers"]
    dropout     = hp["dropout"]
    timesteps   = hp.get("timesteps", 160)

    print(f"[Eval] Loaded hparams from {hparams_path}")
    print(f"        model_type={model_type}, hidden_size={hidden_size}, "
          f"num_layers={num_layers}, dropout={dropout}, timesteps={timesteps}")

    model_classes = {"layernorm": StackedLSTMWithLayerNorm, "basic": LSTMModel}
    if model_type not in model_classes:
        raise ValueError(f"Unknown model_type in hparams: {model_type}")

    ckpt_path = os.path.join(run_dir, f"{which_ckpt}.pth")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    #  Build test dataset
    X_test, y_test = build_test_sequences(test_h5_path, timesteps)
    print(f"[Eval] Test sequences: X={X_test.shape}, y={y_test.shape}")

    test_ds = TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test))
    # No shuffling at evaluation time: the aggregate metrics do not depend on
    # order, and keeping dataset order makes the window indices used in plot
    # titles and file names reproducible and traceable back to the test file.
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    # NRMSE denominator = std of test targets
    nrmse_denom = y_test.reshape(-1).std().item()
    if nrmse_denom < 1e-8:
        nrmse_denom = 1.0  # constant targets: fall back to plain RMSE
    print(f"[Eval] NRMSE denominator (std of test targets): {nrmse_denom:.6f}")

    #  Build the correct model architecture
    model = model_classes[model_type](
        input_size=X_test.shape[2],
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout=dropout
    ).to(device)

    #  Load checkpoint
    print(f"[Eval] Loading checkpoint: {ckpt_path}")
    # NOTE(review): torch.load unpickles the file; only evaluate checkpoints
    # from a trusted source (consider weights_only=True where supported).
    ckpt = torch.load(ckpt_path, map_location=device)
    # accept either full ckpt dict or raw state dict
    state_dict = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
    model.load_state_dict(state_dict)
    model.eval()

    #  Evaluate
    mse_crit = nn.MSELoss()
    mae_crit = nn.L1Loss()

    total_mse = 0.0
    total_mae = 0.0
    sse_test  = 0.0
    n_test    = 0
    n_windows = 0

    all_preds  = []
    all_truths = []

    with torch.no_grad():
        for Xb, yb in test_loader:
            Xb = Xb.to(device)
            yb = yb.to(device)

            out = model(Xb)  # (B, T)
            n_batch = Xb.size(0)

            # Criteria return batch means; weight by batch size so the final
            # average is exact even when the last batch is smaller.
            total_mse += mse_crit(out, yb).item() * n_batch
            total_mae += mae_crit(out, yb).item() * n_batch

            err = out - yb
            sse_test += (err * err).sum().item()
            n_test   += err.numel()

            n_windows += n_batch

            all_preds.append(out.cpu().numpy())
            all_truths.append(yb.cpu().numpy())

    avg_mse = total_mse / max(n_windows, 1)
    avg_mae = total_mae / max(n_windows, 1)
    rmse_test = math.sqrt(max(sse_test / max(n_test, 1), 0.0) + 1e-12)
    nrmse_test = rmse_test / max(nrmse_denom, 1e-12)

    print(f"[Eval] Test MSE  : {avg_mse:.10f}")
    print(f"[Eval] Test MAE  : {avg_mae:.10f}")
    print(f"[Eval] Test NRMSE: {nrmse_test:.10f}")

    preds  = np.vstack(all_preds)
    truths = np.vstack(all_truths)

    #  nMAE per-window
    nmae_new_per_window = _compute_nmae_new(preds, truths)
    nmae_new_mean = nmae_new_per_window.mean()

    print(f"[Eval] New nMAE (mean over windows, normalized): {nmae_new_mean:.3f} %")
    print(f"[Eval] New nMAE range: {nmae_new_per_window.min():.3f} % – {nmae_new_per_window.max():.3f} %")

    #  Plots & metrics saved into run_dir/eval
    eval_dir = os.path.join(run_dir, "eval")
    os.makedirs(eval_dir, exist_ok=True)

    # Save text metrics
    metrics_text = (
        f"Test MSE: {avg_mse:.10f}\n"
        f"Test MAE: {avg_mae:.10f}\n"
        f"Test NRMSE: {nrmse_test:.10f}\n"
        f"New nMAE (mean over windows, normalized): {nmae_new_mean:.3f} %\n"
    )
    with open(os.path.join(eval_dir, "metrics.txt"), "w") as f:
        f.write(metrics_text)

    #  per-window MAE (one value per window, averaged over everything else)
    abs_err = np.abs(truths - preds)
    mae_array = abs_err.reshape(len(abs_err), -1).mean(axis=1)

    # --- Sorted MAE and nMAE plots ---
    _save_sorted_curve(
        mae_array,
        label="Sorted MAE",
        xlabel="Sample Index (sorted by MAE)",
        ylabel="MAE",
        title="MAE per Test Sample (Sorted)",
        out_path=os.path.join(eval_dir, "sorted_mae_plot.png"),
    )
    _save_sorted_curve(
        nmae_new_per_window,
        label="Sorted nMAE (%)",
        xlabel="Sample Index (sorted by nMAE)",
        ylabel="nMAE (%)",
        title="nMAE per Test Sample (Sorted)",
        out_path=os.path.join(eval_dir, "sorted_nmae_plot.png"),
    )

    # Best / worst windows by nMAE
    time_axis = np.arange(timesteps) * dt

    worst_idx_nmae = int(np.argmax(nmae_new_per_window))
    best_idx_nmae  = int(np.argmin(nmae_new_per_window))

    def plot_window(idx, title_prefix, filename):
        """Plot truth vs. prediction for window `idx` with its metrics in the title."""
        plt.figure(figsize=(8, 4))
        plt.plot(time_axis, truths[idx], label="Ground Truth", linewidth=2)
        plt.plot(time_axis, preds[idx],  label="Prediction",   linewidth=1.5, linestyle="--")

        err = truths[idx] - preds[idx]
        mse_val  = (err**2).mean()
        rmse_val = np.sqrt(max(mse_val, 0.0))
        nrmse_val = rmse_val / max(nrmse_denom, 1e-12)
        mae_val  = np.abs(err).mean()
        nmae_val = nmae_new_per_window[idx]

        plt.title(
            f"{title_prefix} — Window #{idx} "
            f"MSE={mse_val:.3e}, RMSE={rmse_val:.3e}, "
            f"NRMSE={nrmse_val:.3e}, MAE={mae_val:.3e}, nMAE={nmae_val:.2f}%"
        )
        plt.xlabel("Time (s)")
        plt.ylabel("Normalized Current")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(eval_dir, filename), dpi=300)
        plt.close()

    plot_window(worst_idx_nmae, "Worst Prediction by nMAE", "worst_by_nmae.png")
    plot_window(best_idx_nmae,  "Best Prediction by nMAE",  "best_by_nmae.png")

    # Random sample windows for inspection (fixed seed => same windows every run)
    rng = np.random.default_rng(0)
    n_total = len(preds)
    k = max(0, min(num_plot_windows, n_total))
    sample_idxs = rng.choice(n_total, size=k, replace=False)

    for idx in sample_idxs:
        plot_window(idx, f"Random sample {idx}", f"sample_{idx}.png")

    print(f"[Eval] Saved metrics and plots to {eval_dir}")


#CLI wrapper 

if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them unchanged
    # to evaluate_run.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--run_dir",
        type=str,
        required=True,
        help="Path to training run directory (where hparams.json and *.pth live)",
    )
    parser.add_argument(
        "--test_data",
        type=str,
        required=True,
        help="Path to test HDF5 file",
    )
    parser.add_argument(
        "--which_ckpt",
        type=str,
        default="best_by_nrmse",
        choices=["best_by_nrmse", "best_by_mse", "last"],
        help="Which checkpoint file in run_dir to evaluate",
    )
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--num_plot_windows", type=int, default=10)
    cli = parser.parse_args()

    # --test_data is the only option whose name differs from the keyword it feeds.
    eval_kwargs = {
        "run_dir": cli.run_dir,
        "test_h5_path": cli.test_data,
        "which_ckpt": cli.which_ckpt,
        "batch_size": cli.batch_size,
        "num_plot_windows": cli.num_plot_windows,
    }
    evaluate_run(**eval_kwargs)

